mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-10 19:46:36 +03:00
Create a self-contained MTProtoState
This frees us from using entire Session objects in something that's supposed to just send and receive items from the net.
This commit is contained in:
parent
cc5753137c
commit
adfe861e9f
|
@ -1,12 +1,7 @@
|
||||||
"""Various helpers not related to the Telegram API itself"""
|
"""Various helpers not related to the Telegram API itself"""
|
||||||
import os
|
import os
|
||||||
import struct
|
|
||||||
from hashlib import sha1, sha256
|
from hashlib import sha1, sha256
|
||||||
|
|
||||||
from .crypto import AES
|
|
||||||
from .errors import SecurityError, BrokenAuthKeyError
|
|
||||||
from .extensions import BinaryReader
|
|
||||||
|
|
||||||
|
|
||||||
# region Multiple utilities
|
# region Multiple utilities
|
||||||
|
|
||||||
|
@ -27,77 +22,6 @@ def ensure_parent_dir_exists(file_path):
|
||||||
# region Cryptographic related utils
|
# region Cryptographic related utils
|
||||||
|
|
||||||
|
|
||||||
def pack_message(session, message):
|
|
||||||
"""Packs a message following MtProto 2.0 guidelines"""
|
|
||||||
# See https://core.telegram.org/mtproto/description
|
|
||||||
data = struct.pack('<qq', session.salt, session.id) + bytes(message)
|
|
||||||
padding = os.urandom(-(len(data) + 12) % 16 + 12)
|
|
||||||
|
|
||||||
# Being substr(what, offset, length); x = 0 for client
|
|
||||||
# "msg_key_large = SHA256(substr(auth_key, 88+x, 32) + pt + padding)"
|
|
||||||
msg_key_large = sha256(
|
|
||||||
session.auth_key.key[88:88 + 32] + data + padding).digest()
|
|
||||||
|
|
||||||
# "msg_key = substr (msg_key_large, 8, 16)"
|
|
||||||
msg_key = msg_key_large[8:24]
|
|
||||||
aes_key, aes_iv = calc_key(session.auth_key.key, msg_key, True)
|
|
||||||
|
|
||||||
key_id = struct.pack('<Q', session.auth_key.key_id)
|
|
||||||
return key_id + msg_key + AES.encrypt_ige(data + padding, aes_key, aes_iv)
|
|
||||||
|
|
||||||
|
|
||||||
def unpack_message(session, body):
|
|
||||||
"""Unpacks a message following MtProto 2.0 guidelines"""
|
|
||||||
# See https://core.telegram.org/mtproto/description
|
|
||||||
if len(body) < 8:
|
|
||||||
if body == b'l\xfe\xff\xff':
|
|
||||||
raise BrokenAuthKeyError()
|
|
||||||
else:
|
|
||||||
raise BufferError("Can't decode packet ({})".format(body))
|
|
||||||
|
|
||||||
key_id = struct.unpack('<Q', body[:8])[0]
|
|
||||||
if key_id != session.auth_key.key_id:
|
|
||||||
raise SecurityError('Server replied with an invalid auth key')
|
|
||||||
|
|
||||||
msg_key = body[8:24]
|
|
||||||
aes_key, aes_iv = calc_key(session.auth_key.key, msg_key, False)
|
|
||||||
data = BinaryReader(AES.decrypt_ige(body[24:], aes_key, aes_iv))
|
|
||||||
|
|
||||||
data.read_long() # remote_salt
|
|
||||||
if data.read_long() != session.id:
|
|
||||||
raise SecurityError('Server replied with a wrong session ID')
|
|
||||||
|
|
||||||
remote_msg_id = data.read_long()
|
|
||||||
remote_sequence = data.read_int()
|
|
||||||
msg_len = data.read_int()
|
|
||||||
message = data.read(msg_len)
|
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/security_guidelines
|
|
||||||
# Sections "checking sha256 hash" and "message length"
|
|
||||||
if msg_key != sha256(
|
|
||||||
session.auth_key.key[96:96 + 32] + data.get_bytes()).digest()[8:24]:
|
|
||||||
raise SecurityError("Received msg_key doesn't match with expected one")
|
|
||||||
|
|
||||||
return message, remote_msg_id, remote_sequence
|
|
||||||
|
|
||||||
|
|
||||||
def calc_key(auth_key, msg_key, client):
|
|
||||||
"""
|
|
||||||
Calculate the key based on Telegram guidelines
|
|
||||||
for MtProto 2, specifying whether it's the client or not.
|
|
||||||
"""
|
|
||||||
# https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
|
|
||||||
x = 0 if client else 8
|
|
||||||
|
|
||||||
sha256a = sha256(msg_key + auth_key[x: x + 36]).digest()
|
|
||||||
sha256b = sha256(auth_key[x + 40:x + 76] + msg_key).digest()
|
|
||||||
|
|
||||||
aes_key = sha256a[:8] + sha256b[8:24] + sha256a[24:32]
|
|
||||||
aes_iv = sha256b[:8] + sha256a[8:24] + sha256b[24:32]
|
|
||||||
|
|
||||||
return aes_key, aes_iv
|
|
||||||
|
|
||||||
|
|
||||||
def generate_key_data_from_nonce(server_nonce, new_nonce):
|
def generate_key_data_from_nonce(server_nonce, new_nonce):
|
||||||
"""Generates the key data corresponding to the given nonce"""
|
"""Generates the key data corresponding to the given nonce"""
|
||||||
server_nonce = server_nonce.to_bytes(16, 'little', signed=True)
|
server_nonce = server_nonce.to_bytes(16, 'little', signed=True)
|
||||||
|
|
|
@ -3,13 +3,13 @@ import logging
|
||||||
|
|
||||||
from . import MTProtoPlainSender, authenticator
|
from . import MTProtoPlainSender, authenticator
|
||||||
from .connection import ConnectionTcpFull
|
from .connection import ConnectionTcpFull
|
||||||
from .. import helpers, utils
|
from .. import utils
|
||||||
from ..errors import (
|
from ..errors import (
|
||||||
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
|
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
|
||||||
rpc_message_to_error
|
rpc_message_to_error
|
||||||
)
|
)
|
||||||
from ..extensions import BinaryReader
|
from ..extensions import BinaryReader
|
||||||
from ..tl import TLMessage, MessageContainer, GzipPacked
|
from ..tl import MessageContainer, GzipPacked
|
||||||
from ..tl.functions.auth import LogOutRequest
|
from ..tl.functions.auth import LogOutRequest
|
||||||
from ..tl.types import (
|
from ..tl.types import (
|
||||||
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
|
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
|
||||||
|
@ -39,8 +39,8 @@ class MTProtoSender:
|
||||||
A new authorization key will be generated on connection if no other
|
A new authorization key will be generated on connection if no other
|
||||||
key exists yet.
|
key exists yet.
|
||||||
"""
|
"""
|
||||||
def __init__(self, session, retries=5):
|
def __init__(self, state, retries=5):
|
||||||
self.session = session
|
self.state = state
|
||||||
self._connection = ConnectionTcpFull()
|
self._connection = ConnectionTcpFull()
|
||||||
self._ip = None
|
self._ip = None
|
||||||
self._port = None
|
self._port = None
|
||||||
|
@ -171,21 +171,17 @@ class MTProtoSender:
|
||||||
# a `Future` that you need to further ``await`` instead of the
|
# a `Future` that you need to further ``await`` instead of the
|
||||||
# currently double ``await (await send())``?
|
# currently double ``await (await send())``?
|
||||||
if utils.is_list_like(request):
|
if utils.is_list_like(request):
|
||||||
if not ordered:
|
|
||||||
# False-y values must be None to do after_id = ordered and ...
|
|
||||||
ordered = None
|
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
after_id = None
|
after = None
|
||||||
for r in request:
|
for r in request:
|
||||||
message = TLMessage(self.session, r, after_id=after_id)
|
message = self.state.create_message(r, after=after)
|
||||||
self._pending_messages[message.msg_id] = message
|
self._pending_messages[message.msg_id] = message
|
||||||
after_id = ordered and message.msg_id
|
|
||||||
await self._send_queue.put(message)
|
await self._send_queue.put(message)
|
||||||
result.append(message.future)
|
result.append(message.future)
|
||||||
|
after = ordered and message
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
message = TLMessage(self.session, request)
|
message = self.state.create_message(request)
|
||||||
self._pending_messages[message.msg_id] = message
|
self._pending_messages[message.msg_id] = message
|
||||||
await self._send_queue.put(message)
|
await self._send_queue.put(message)
|
||||||
return message.future
|
return message.future
|
||||||
|
@ -215,13 +211,13 @@ class MTProtoSender:
|
||||||
raise _last_error
|
raise _last_error
|
||||||
|
|
||||||
__log__.debug('Connection success!')
|
__log__.debug('Connection success!')
|
||||||
if self.session.auth_key is None:
|
if self.state.auth_key is None:
|
||||||
_last_error = SecurityError()
|
_last_error = SecurityError()
|
||||||
plain = MTProtoPlainSender(self._connection)
|
plain = MTProtoPlainSender(self._connection)
|
||||||
for retry in range(1, self._retries + 1):
|
for retry in range(1, self._retries + 1):
|
||||||
try:
|
try:
|
||||||
__log__.debug('New auth_key attempt {}...'.format(retry))
|
__log__.debug('New auth_key attempt {}...'.format(retry))
|
||||||
self.session.auth_key, self.session.time_offset =\
|
self.state.auth_key, self.state.time_offset =\
|
||||||
await authenticator.do_authentication(plain)
|
await authenticator.do_authentication(plain)
|
||||||
except (SecurityError, AssertionError) as e:
|
except (SecurityError, AssertionError) as e:
|
||||||
_last_error = e
|
_last_error = e
|
||||||
|
@ -268,13 +264,14 @@ class MTProtoSender:
|
||||||
"""
|
"""
|
||||||
while self._user_connected and not self._reconnecting:
|
while self._user_connected and not self._reconnecting:
|
||||||
if self._pending_ack:
|
if self._pending_ack:
|
||||||
await self._send_queue.put(TLMessage(
|
await self._send_queue.put(self.state.create_message(
|
||||||
self.session, MsgsAck(list(self._pending_ack))))
|
MsgsAck(list(self._pending_ack))
|
||||||
|
))
|
||||||
self._pending_ack.clear()
|
self._pending_ack.clear()
|
||||||
|
|
||||||
messages = await self._send_queue.get()
|
messages = await self._send_queue.get()
|
||||||
if isinstance(messages, list):
|
if isinstance(messages, list):
|
||||||
message = TLMessage(self.session, MessageContainer(messages))
|
message = self.state.create_message(MessageContainer(messages))
|
||||||
self._pending_messages[message.msg_id] = message
|
self._pending_messages[message.msg_id] = message
|
||||||
self._pending_containers.append(message)
|
self._pending_containers.append(message)
|
||||||
else:
|
else:
|
||||||
|
@ -283,7 +280,7 @@ class MTProtoSender:
|
||||||
|
|
||||||
__log__.debug('Packing {} outgoing message(s)...'
|
__log__.debug('Packing {} outgoing message(s)...'
|
||||||
.format(len(messages)))
|
.format(len(messages)))
|
||||||
body = helpers.pack_message(self.session, message)
|
body = self.state.pack_message(message)
|
||||||
|
|
||||||
while not any(m.future.cancelled() for m in messages):
|
while not any(m.future.cancelled() for m in messages):
|
||||||
try:
|
try:
|
||||||
|
@ -333,8 +330,7 @@ class MTProtoSender:
|
||||||
# TODO Check salt, session_id and sequence_number
|
# TODO Check salt, session_id and sequence_number
|
||||||
__log__.debug('Decoding packet of {} bytes...'.format(len(body)))
|
__log__.debug('Decoding packet of {} bytes...'.format(len(body)))
|
||||||
try:
|
try:
|
||||||
message, remote_msg_id, remote_seq =\
|
message = self.state.unpack_message(body)
|
||||||
helpers.unpack_message(self.session, body)
|
|
||||||
except (BrokenAuthKeyError, BufferError) as e:
|
except (BrokenAuthKeyError, BufferError) as e:
|
||||||
# The authorization key may be broken if a message was
|
# The authorization key may be broken if a message was
|
||||||
# sent malformed, or if the authkey truly is corrupted.
|
# sent malformed, or if the authkey truly is corrupted.
|
||||||
|
@ -346,7 +342,7 @@ class MTProtoSender:
|
||||||
# TODO Is it possible to detect malformed messages vs
|
# TODO Is it possible to detect malformed messages vs
|
||||||
# an actually broken authkey?
|
# an actually broken authkey?
|
||||||
__log__.warning('Broken authorization key?: {}'.format(e))
|
__log__.warning('Broken authorization key?: {}'.format(e))
|
||||||
self.session.auth_key = None
|
self.state.auth_key = None
|
||||||
asyncio.ensure_future(self._reconnect())
|
asyncio.ensure_future(self._reconnect())
|
||||||
break
|
break
|
||||||
except SecurityError as e:
|
except SecurityError as e:
|
||||||
|
@ -357,28 +353,27 @@ class MTProtoSender:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
with BinaryReader(message) as reader:
|
with BinaryReader(message.body) as reader:
|
||||||
await self._process_message(
|
await self._process_message(message, reader)
|
||||||
remote_msg_id, remote_seq, reader)
|
|
||||||
except TypeNotFoundError as e:
|
except TypeNotFoundError as e:
|
||||||
__log__.warning('Could not decode received message: {}, '
|
__log__.warning('Could not decode received message: {}, '
|
||||||
'raw bytes: {!r}'.format(e, message))
|
'raw bytes: {!r}'.format(e, message))
|
||||||
|
|
||||||
# Response Handlers
|
# Response Handlers
|
||||||
|
|
||||||
async def _process_message(self, msg_id, seq, reader):
|
async def _process_message(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Adds the given message to the list of messages that must be
|
Adds the given message to the list of messages that must be
|
||||||
acknowledged and dispatches control to different ``_handle_*``
|
acknowledged and dispatches control to different ``_handle_*``
|
||||||
method based on its type.
|
method based on its type.
|
||||||
"""
|
"""
|
||||||
self._pending_ack.add(msg_id)
|
self._pending_ack.add(message.msg_id)
|
||||||
code = reader.read_int(signed=False)
|
code = reader.read_int(signed=False)
|
||||||
reader.seek(-4)
|
reader.seek(-4)
|
||||||
handler = self._handlers.get(code, self._handle_update)
|
handler = self._handlers.get(code, self._handle_update)
|
||||||
await handler(msg_id, seq, reader)
|
await handler(message, reader)
|
||||||
|
|
||||||
async def _handle_rpc_result(self, msg_id, seq, reader):
|
async def _handle_rpc_result(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Handles the result for Remote Procedure Calls:
|
Handles the result for Remote Procedure Calls:
|
||||||
|
|
||||||
|
@ -395,19 +390,14 @@ class MTProtoSender:
|
||||||
__log__.debug('Handling RPC result for message {}'.format(message_id))
|
__log__.debug('Handling RPC result for message {}'.format(message_id))
|
||||||
message = self._pending_messages.pop(message_id, None)
|
message = self._pending_messages.pop(message_id, None)
|
||||||
if inner_code == 0x2144ca19: # RPC Error
|
if inner_code == 0x2144ca19: # RPC Error
|
||||||
|
# TODO Report errors if possible/enabled
|
||||||
reader.seek(4)
|
reader.seek(4)
|
||||||
if self.session.report_errors and message:
|
error = rpc_message_to_error(reader.read_int(),
|
||||||
error = rpc_message_to_error(
|
reader.tgread_string())
|
||||||
reader.read_int(), reader.tgread_string(),
|
|
||||||
report_method=type(message.request).CONSTRUCTOR_ID
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
error = rpc_message_to_error(
|
|
||||||
reader.read_int(), reader.tgread_string()
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._send_queue.put(
|
await self._send_queue.put(self.state.create_message(
|
||||||
TLMessage(self.session, MsgsAck([msg_id])))
|
MsgsAck([message.msg_id])
|
||||||
|
))
|
||||||
|
|
||||||
if not message.future.cancelled():
|
if not message.future.cancelled():
|
||||||
message.future.set_exception(error)
|
message.future.set_exception(error)
|
||||||
|
@ -419,7 +409,7 @@ class MTProtoSender:
|
||||||
else:
|
else:
|
||||||
result = message.request.read_result(reader)
|
result = message.request.read_result(reader)
|
||||||
|
|
||||||
self.session.process_entities(result)
|
# TODO Process entities
|
||||||
if not message.future.cancelled():
|
if not message.future.cancelled():
|
||||||
message.future.set_result(result)
|
message.future.set_result(result)
|
||||||
return
|
return
|
||||||
|
@ -428,19 +418,18 @@ class MTProtoSender:
|
||||||
__log__.info('Received response without parent request: {}'
|
__log__.info('Received response without parent request: {}'
|
||||||
.format(reader.tgread_object()))
|
.format(reader.tgread_object()))
|
||||||
|
|
||||||
async def _handle_container(self, msg_id, seq, reader):
|
async def _handle_container(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Processes the inner messages of a container with many of them:
|
Processes the inner messages of a container with many of them:
|
||||||
|
|
||||||
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
|
msg_container#73f1f8dc messages:vector<%Message> = MessageContainer;
|
||||||
"""
|
"""
|
||||||
__log__.debug('Handling container')
|
__log__.debug('Handling container')
|
||||||
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
|
for inner_message in MessageContainer.iter_read(reader):
|
||||||
next_position = reader.tell_position() + inner_len
|
with BinaryReader(inner_message.body) as inner_reader:
|
||||||
await self._process_message(inner_msg_id, seq, reader)
|
await self._process_message(inner_message, inner_reader)
|
||||||
reader.set_position(next_position) # Ensure reading correctly
|
|
||||||
|
|
||||||
async def _handle_gzip_packed(self, msg_id, seq, reader):
|
async def _handle_gzip_packed(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Unpacks the data from a gzipped object and processes it:
|
Unpacks the data from a gzipped object and processes it:
|
||||||
|
|
||||||
|
@ -448,16 +437,16 @@ class MTProtoSender:
|
||||||
"""
|
"""
|
||||||
__log__.debug('Handling gzipped data')
|
__log__.debug('Handling gzipped data')
|
||||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||||
await self._process_message(msg_id, seq, compressed_reader)
|
await self._process_message(message, compressed_reader)
|
||||||
|
|
||||||
async def _handle_update(self, msg_id, seq, reader):
|
async def _handle_update(self, message, reader):
|
||||||
obj = reader.tgread_object()
|
obj = reader.tgread_object()
|
||||||
__log__.debug('Handling update {}'.format(obj.__class__.__name__))
|
__log__.debug('Handling update {}'.format(obj.__class__.__name__))
|
||||||
|
|
||||||
# TODO Further handling of the update
|
# TODO Further handling of the update
|
||||||
self.session.process_entities(obj)
|
# TODO Process entities
|
||||||
|
|
||||||
async def _handle_pong(self, msg_id, seq, reader):
|
async def _handle_pong(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Handles pong results, which don't come inside a ``rpc_result``
|
Handles pong results, which don't come inside a ``rpc_result``
|
||||||
but are still sent through a request:
|
but are still sent through a request:
|
||||||
|
@ -470,7 +459,7 @@ class MTProtoSender:
|
||||||
if message:
|
if message:
|
||||||
message.future.set_result(pong)
|
message.future.set_result(pong)
|
||||||
|
|
||||||
async def _handle_bad_server_salt(self, msg_id, seq, reader):
|
async def _handle_bad_server_salt(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Corrects the currently used server salt to use the right value
|
Corrects the currently used server salt to use the right value
|
||||||
before enqueuing the rejected message to be re-sent:
|
before enqueuing the rejected message to be re-sent:
|
||||||
|
@ -480,11 +469,10 @@ class MTProtoSender:
|
||||||
"""
|
"""
|
||||||
__log__.debug('Handling bad salt')
|
__log__.debug('Handling bad salt')
|
||||||
bad_salt = reader.tgread_object()
|
bad_salt = reader.tgread_object()
|
||||||
self.session.salt = bad_salt.new_server_salt
|
self.state.salt = bad_salt.new_server_salt
|
||||||
self.session.save()
|
|
||||||
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
|
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
|
||||||
|
|
||||||
async def _handle_bad_notification(self, msg_id, seq, reader):
|
async def _handle_bad_notification(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Adjusts the current state to be correct based on the
|
Adjusts the current state to be correct based on the
|
||||||
received bad message notification whenever possible:
|
received bad message notification whenever possible:
|
||||||
|
@ -497,14 +485,14 @@ class MTProtoSender:
|
||||||
if bad_msg.error_code in (16, 17):
|
if bad_msg.error_code in (16, 17):
|
||||||
# Sent msg_id too low or too high (respectively).
|
# Sent msg_id too low or too high (respectively).
|
||||||
# Use the current msg_id to determine the right time offset.
|
# Use the current msg_id to determine the right time offset.
|
||||||
self.session.update_time_offset(correct_msg_id=msg_id)
|
self.state.update_time_offset(correct_msg_id=message.msg_id)
|
||||||
elif bad_msg.error_code == 32:
|
elif bad_msg.error_code == 32:
|
||||||
# msg_seqno too low, so just pump it up by some "large" amount
|
# msg_seqno too low, so just pump it up by some "large" amount
|
||||||
# TODO A better fix would be to start with a new fresh session ID
|
# TODO A better fix would be to start with a new fresh session ID
|
||||||
self.session.sequence += 64
|
self.state._sequence += 64
|
||||||
elif bad_msg.error_code == 33:
|
elif bad_msg.error_code == 33:
|
||||||
# msg_seqno too high never seems to happen but just in case
|
# msg_seqno too high never seems to happen but just in case
|
||||||
self.session.sequence -= 16
|
self.state._sequence -= 16
|
||||||
else:
|
else:
|
||||||
msg = self._pending_messages.pop(bad_msg.bad_msg_id, None)
|
msg = self._pending_messages.pop(bad_msg.bad_msg_id, None)
|
||||||
if msg:
|
if msg:
|
||||||
|
@ -514,7 +502,7 @@ class MTProtoSender:
|
||||||
# Messages are to be re-sent once we've corrected the issue
|
# Messages are to be re-sent once we've corrected the issue
|
||||||
await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id])
|
await self._send_queue.put(self._pending_messages[bad_msg.bad_msg_id])
|
||||||
|
|
||||||
async def _handle_detailed_info(self, msg_id, seq, reader):
|
async def _handle_detailed_info(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Updates the current status with the received detailed information:
|
Updates the current status with the received detailed information:
|
||||||
|
|
||||||
|
@ -525,7 +513,7 @@ class MTProtoSender:
|
||||||
__log__.debug('Handling detailed info')
|
__log__.debug('Handling detailed info')
|
||||||
self._pending_ack.add(reader.tgread_object().answer_msg_id)
|
self._pending_ack.add(reader.tgread_object().answer_msg_id)
|
||||||
|
|
||||||
async def _handle_new_detailed_info(self, msg_id, seq, reader):
|
async def _handle_new_detailed_info(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Updates the current status with the received detailed information:
|
Updates the current status with the received detailed information:
|
||||||
|
|
||||||
|
@ -536,7 +524,7 @@ class MTProtoSender:
|
||||||
__log__.debug('Handling new detailed info')
|
__log__.debug('Handling new detailed info')
|
||||||
self._pending_ack.add(reader.tgread_object().answer_msg_id)
|
self._pending_ack.add(reader.tgread_object().answer_msg_id)
|
||||||
|
|
||||||
async def _handle_new_session_created(self, msg_id, seq, reader):
|
async def _handle_new_session_created(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Updates the current status with the received session information:
|
Updates the current status with the received session information:
|
||||||
|
|
||||||
|
@ -545,7 +533,7 @@ class MTProtoSender:
|
||||||
"""
|
"""
|
||||||
# TODO https://goo.gl/LMyN7A
|
# TODO https://goo.gl/LMyN7A
|
||||||
__log__.debug('Handling new session created')
|
__log__.debug('Handling new session created')
|
||||||
self.session.salt = reader.tgread_object().server_salt
|
self.state.salt = reader.tgread_object().server_salt
|
||||||
|
|
||||||
def _clean_containers(self, msg_ids):
|
def _clean_containers(self, msg_ids):
|
||||||
"""
|
"""
|
||||||
|
@ -564,7 +552,7 @@ class MTProtoSender:
|
||||||
del self._pending_messages[message.msg_id]
|
del self._pending_messages[message.msg_id]
|
||||||
break
|
break
|
||||||
|
|
||||||
async def _handle_ack(self, msg_id, seq, reader):
|
async def _handle_ack(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Handles a server acknowledge about our messages. Normally
|
Handles a server acknowledge about our messages. Normally
|
||||||
these can be ignored except in the case of ``auth.logOut``:
|
these can be ignored except in the case of ``auth.logOut``:
|
||||||
|
@ -590,7 +578,7 @@ class MTProtoSender:
|
||||||
del self._pending_messages[msg_id]
|
del self._pending_messages[msg_id]
|
||||||
msg.future.set_result(True)
|
msg.future.set_result(True)
|
||||||
|
|
||||||
async def _handle_future_salts(self, msg_id, seq, reader):
|
async def _handle_future_salts(self, message, reader):
|
||||||
"""
|
"""
|
||||||
Handles future salt results, which don't come inside a
|
Handles future salt results, which don't come inside a
|
||||||
``rpc_result`` but are still sent through a request:
|
``rpc_result`` but are still sent through a request:
|
||||||
|
@ -602,7 +590,7 @@ class MTProtoSender:
|
||||||
# correct one whenever the salt in use expires.
|
# correct one whenever the salt in use expires.
|
||||||
__log__.debug('Handling future salts')
|
__log__.debug('Handling future salts')
|
||||||
salts = reader.tgread_object()
|
salts = reader.tgread_object()
|
||||||
msg = self._pending_messages.pop(msg_id, None)
|
msg = self._pending_messages.pop(message.msg_id, None)
|
||||||
if msg:
|
if msg:
|
||||||
msg.future.set_result(salts)
|
msg.future.set_result(salts)
|
||||||
|
|
||||||
|
|
158
telethon/network/mtprotostate.py
Normal file
158
telethon/network/mtprotostate.py
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
from hashlib import sha256
|
||||||
|
|
||||||
|
from ..crypto import AES
|
||||||
|
from ..errors import SecurityError, BrokenAuthKeyError
|
||||||
|
from ..extensions import BinaryReader
|
||||||
|
from ..tl import TLMessage
|
||||||
|
|
||||||
|
|
||||||
|
class MTProtoState:
|
||||||
|
"""
|
||||||
|
`telethon.network.mtprotosender.MTProtoSender` needs to hold a state
|
||||||
|
in order to be able to encrypt and decrypt incoming/outgoing messages,
|
||||||
|
as well as generating the message IDs. Instances of this class hold
|
||||||
|
together all the required information.
|
||||||
|
|
||||||
|
It doesn't make sense to use `telethon.sessions.abstract.Session` for
|
||||||
|
the sender because the sender should *not* be concerned about storing
|
||||||
|
this information to disk, as one may create as many senders as they
|
||||||
|
desire to any other data center, or some CDN. Using the same session
|
||||||
|
for all these is not a good idea as each need their own authkey, and
|
||||||
|
the concept of "copying" sessions with the unnecessary entities or
|
||||||
|
updates state for these connections doesn't make sense.
|
||||||
|
"""
|
||||||
|
def __init__(self, auth_key):
|
||||||
|
# Session IDs can be random on every connection
|
||||||
|
self.id = struct.unpack('q', os.urandom(8))[0]
|
||||||
|
self.auth_key = auth_key
|
||||||
|
self.time_offset = 0
|
||||||
|
self.salt = 0
|
||||||
|
self._sequence = 0
|
||||||
|
self._last_msg_id = 0
|
||||||
|
|
||||||
|
def create_message(self, request, after=None):
|
||||||
|
"""
|
||||||
|
Creates a new `telethon.tl.tl_message.TLMessage` from
|
||||||
|
the given `telethon.tl.tlobject.TLObject` instance.
|
||||||
|
"""
|
||||||
|
return TLMessage(
|
||||||
|
msg_id=self._get_new_msg_id(),
|
||||||
|
seq_no=self._get_seq_no(request.content_related),
|
||||||
|
request=request,
|
||||||
|
after_id=after.msg_id if after else None
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _calc_key(auth_key, msg_key, client):
|
||||||
|
"""
|
||||||
|
Calculate the key based on Telegram guidelines for MTProto 2,
|
||||||
|
specifying whether it's the client or not. See
|
||||||
|
https://core.telegram.org/mtproto/description#defining-aes-key-and-initialization-vector
|
||||||
|
"""
|
||||||
|
x = 0 if client else 8
|
||||||
|
sha256a = sha256(msg_key + auth_key[x: x + 36]).digest()
|
||||||
|
sha256b = sha256(auth_key[x + 40:x + 76] + msg_key).digest()
|
||||||
|
|
||||||
|
aes_key = sha256a[:8] + sha256b[8:24] + sha256a[24:32]
|
||||||
|
aes_iv = sha256b[:8] + sha256a[8:24] + sha256b[24:32]
|
||||||
|
|
||||||
|
return aes_key, aes_iv
|
||||||
|
|
||||||
|
def pack_message(self, message):
|
||||||
|
"""
|
||||||
|
Packs the given `telethon.tl.tl_message.TLMessage` using the
|
||||||
|
current authorization key following MTProto 2.0 guidelines.
|
||||||
|
|
||||||
|
See https://core.telegram.org/mtproto/description.
|
||||||
|
"""
|
||||||
|
data = struct.pack('<qq', self.salt, self.id) + bytes(message)
|
||||||
|
padding = os.urandom(-(len(data) + 12) % 16 + 12)
|
||||||
|
|
||||||
|
# Being substr(what, offset, length); x = 0 for client
|
||||||
|
# "msg_key_large = SHA256(substr(auth_key, 88+x, 32) + pt + padding)"
|
||||||
|
msg_key_large = sha256(
|
||||||
|
self.auth_key.key[88:88 + 32] + data + padding).digest()
|
||||||
|
|
||||||
|
# "msg_key = substr (msg_key_large, 8, 16)"
|
||||||
|
msg_key = msg_key_large[8:24]
|
||||||
|
aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, True)
|
||||||
|
|
||||||
|
key_id = struct.pack('<Q', self.auth_key.key_id)
|
||||||
|
return (key_id + msg_key +
|
||||||
|
AES.encrypt_ige(data + padding, aes_key, aes_iv))
|
||||||
|
|
||||||
|
def unpack_message(self, body):
|
||||||
|
"""
|
||||||
|
Inverse of `pack_message` for incoming server messages.
|
||||||
|
"""
|
||||||
|
if len(body) < 8:
|
||||||
|
if body == b'l\xfe\xff\xff':
|
||||||
|
raise BrokenAuthKeyError()
|
||||||
|
else:
|
||||||
|
raise BufferError("Can't decode packet ({})".format(body))
|
||||||
|
|
||||||
|
key_id = struct.unpack('<Q', body[:8])[0]
|
||||||
|
if key_id != self.auth_key.key_id:
|
||||||
|
raise SecurityError('Server replied with an invalid auth key')
|
||||||
|
|
||||||
|
msg_key = body[8:24]
|
||||||
|
aes_key, aes_iv = self._calc_key(self.auth_key.key, msg_key, False)
|
||||||
|
data = BinaryReader(AES.decrypt_ige(body[24:], aes_key, aes_iv))
|
||||||
|
|
||||||
|
data.read_long() # remote_salt
|
||||||
|
if data.read_long() != self.id:
|
||||||
|
raise SecurityError('Server replied with a wrong session ID')
|
||||||
|
|
||||||
|
remote_msg_id = data.read_long()
|
||||||
|
remote_sequence = data.read_int()
|
||||||
|
msg_len = data.read_int()
|
||||||
|
message = data.read(msg_len)
|
||||||
|
|
||||||
|
# https://core.telegram.org/mtproto/security_guidelines
|
||||||
|
# Sections "checking sha256 hash" and "message length"
|
||||||
|
our_key = sha256(self.auth_key.key[96:96 + 32] + data.get_bytes())
|
||||||
|
if msg_key != our_key.digest()[8:24]:
|
||||||
|
raise SecurityError(
|
||||||
|
"Received msg_key doesn't match with expected one")
|
||||||
|
|
||||||
|
return TLMessage(remote_msg_id, remote_sequence, body=message)
|
||||||
|
|
||||||
|
def _get_new_msg_id(self):
|
||||||
|
"""
|
||||||
|
Generates a new unique message ID based on the current
|
||||||
|
time (in ms) since epoch, applying a known time offset.
|
||||||
|
"""
|
||||||
|
now = time.time() + self.time_offset
|
||||||
|
nanoseconds = int((now - int(now)) * 1e+9)
|
||||||
|
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
|
||||||
|
|
||||||
|
if self._last_msg_id >= new_msg_id:
|
||||||
|
new_msg_id = self._last_msg_id + 4
|
||||||
|
|
||||||
|
self._last_msg_id = new_msg_id
|
||||||
|
return new_msg_id
|
||||||
|
|
||||||
|
def update_time_offset(self, correct_msg_id):
|
||||||
|
"""
|
||||||
|
Updates the time offset to the correct
|
||||||
|
one given a known valid message ID.
|
||||||
|
"""
|
||||||
|
now = int(time.time())
|
||||||
|
correct = correct_msg_id >> 32
|
||||||
|
self.time_offset = correct - now
|
||||||
|
self._last_msg_id = 0
|
||||||
|
|
||||||
|
def _get_seq_no(self, content_related):
|
||||||
|
"""
|
||||||
|
Generates the next sequence number depending on whether
|
||||||
|
it should be for a content-related query or not.
|
||||||
|
"""
|
||||||
|
if content_related:
|
||||||
|
result = self._sequence * 2 + 1
|
||||||
|
self._sequence += 1
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
return self._sequence * 2
|
|
@ -1,6 +1,7 @@
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
from . import TLObject
|
from . import TLObject
|
||||||
|
from .tl_message import TLMessage
|
||||||
|
|
||||||
|
|
||||||
class MessageContainer(TLObject):
|
class MessageContainer(TLObject):
|
||||||
|
@ -33,7 +34,8 @@ class MessageContainer(TLObject):
|
||||||
inner_msg_id = reader.read_long()
|
inner_msg_id = reader.read_long()
|
||||||
inner_sequence = reader.read_int()
|
inner_sequence = reader.read_int()
|
||||||
inner_length = reader.read_int()
|
inner_length = reader.read_int()
|
||||||
yield inner_msg_id, inner_sequence, inner_length
|
yield TLMessage(inner_msg_id, inner_sequence,
|
||||||
|
body=reader.read(inner_length))
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return TLObject.pretty_format(self)
|
return TLObject.pretty_format(self)
|
||||||
|
|
|
@ -20,15 +20,23 @@ class TLMessage(TLObject):
|
||||||
sent `TLMessage`, and this result can be represented as a `Future`
|
sent `TLMessage`, and this result can be represented as a `Future`
|
||||||
that will eventually be set with either a result, error or cancelled.
|
that will eventually be set with either a result, error or cancelled.
|
||||||
"""
|
"""
|
||||||
def __init__(self, session, request, after_id=None):
|
def __init__(self, msg_id, seq_no, body=None, request=None, after_id=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
del self.content_related
|
self.msg_id = msg_id
|
||||||
self.msg_id = session.get_new_msg_id()
|
self.seq_no = seq_no
|
||||||
self.seq_no = session.generate_sequence(request.content_related)
|
|
||||||
self.request = request
|
|
||||||
self.container_msg_id = None
|
self.container_msg_id = None
|
||||||
self.future = asyncio.Future()
|
self.future = asyncio.Future()
|
||||||
|
|
||||||
|
# TODO Perhaps it's possible to merge body and request?
|
||||||
|
# We need things like rpc_result and gzip_packed to
|
||||||
|
# be readable by the ``BinaryReader`` for such purpose.
|
||||||
|
|
||||||
|
# Used for incoming, not-decoded messages
|
||||||
|
self.body = body
|
||||||
|
|
||||||
|
# Used for outgoing, not-encoded messages
|
||||||
|
self.request = request
|
||||||
|
|
||||||
# After which message ID this one should run. We do this so
|
# After which message ID this one should run. We do this so
|
||||||
# InvokeAfterMsgRequest is transparent to the user and we can
|
# InvokeAfterMsgRequest is transparent to the user and we can
|
||||||
# easily invoke after while confirming the original request.
|
# easily invoke after while confirming the original request.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user