# Telethon/telethon/network/mtproto_sender.py
#
# 380 lines, 14 KiB, Python
# Raw Normal View History

import gzip
from threading import RLock, Thread
from .. import helpers as utils
from ..crypto import AES
from ..errors import BadMessageError, InvalidDCError, rpc_message_to_error
from ..tl.all_tlobjects import tlobjects
from ..tl.types import MsgsAck
from ..extensions import BinaryReader, BinaryWriter
import logging
# Attach a no-op handler to this module's logger (the standard-library
# recommendation for libraries) so that, unless the application configures
# logging itself, records from this module are not printed via logging's
# fallback "last resort" handler.
logging.getLogger(__name__).addHandler(logging.NullHandler())
class MtProtoSender:
    """MTProto Mobile Protocol sender
       (https://core.telegram.org/mtproto/description)

    Wraps a ``connection`` (raw byte transport: connect/close/send/recv)
    and a ``session`` (auth key, salt, ids, sequence numbers) in order to
    encrypt outgoing requests and to decode and dispatch incoming messages.
    """

    def __init__(self, connection, session):
        """Creates a new MtProtoSender configured to send messages through
           'connection' and using the parameters from 'session'.
        """
        self.connection = connection
        self.session = session
        self._logger = logging.getLogger(__name__)

        self._need_confirmation = []  # Message IDs that need confirmation
        self._pending_receive = []  # Requests sent waiting to be received

        # Store an RLock instance to make this class safely multi-threaded.
        # Note: only send() acquires it; receive() does not.
        self._lock = RLock()

        # Used when logging out, the only request that seems to use 'ack'
        # TODO There might be a better way to handle msgs_ack requests
        self.logging_out = False

        # Every unhandled result gets passed to these callbacks, which
        # should be functions accepting a single parameter: a TLObject.
        # This should only be Update(s), although it can actually be any type.
        #
        # The thread from which these callbacks are called can be any.
        #
        # The creator of the MtProtoSender is responsible for setting this
        # to point to the list wherever their callbacks reside.
        self.unhandled_callbacks = None

    def connect(self):
        """Connects to the server"""
        self.connection.connect()

    def is_connected(self):
        """Returns whatever the underlying connection's is_connected() says"""
        return self.connection.is_connected()

    def disconnect(self):
        """Disconnects from the server"""
        self.connection.close()

    # region Send and receive

    def send(self, request):
        """Sends the specified MTProtoRequest, previously sending any message
           which needed confirmation.

           The request is appended to the pending list so that a later
           receive() can match the server's reply to it, and the session is
           saved afterwards. The whole operation holds ``self._lock``.
        """
        # Now only us can be using this method
        with self._lock:
            self._logger.debug('send() acquired the lock')

            # If any message needs confirmation send an AckRequest first
            self._send_acknowledges()

            # Finally send our packed request
            with BinaryWriter() as writer:
                request.on_send(writer)
                self._send_packet(writer.get_bytes(), request)
                self._pending_receive.append(request)

            # And update the saved session
            self.session.save()

        self._logger.debug('send() released the lock')

    def _send_acknowledges(self):
        """Sends a messages acknowledge for all those who _need_confirmation
           and then clears that list. Does nothing if it is empty.
        """
        if self._need_confirmation:
            # Give MsgsAck its own copy: the shared list is cleared right
            # below and must not be aliased into the object that was sent.
            msgs_ack = MsgsAck(list(self._need_confirmation))
            with BinaryWriter() as writer:
                msgs_ack.on_send(writer)
                self._send_packet(writer.get_bytes(), msgs_ack)

            del self._need_confirmation[:]

    def receive(self):
        """Receives a single message from the connected endpoint.

           This method returns nothing, and will only affect other parts
           of the MtProtoSender such as the updates callback being fired
           or a pending request being confirmed.

           Note that this method does not acquire ``self._lock``.
        """
        # TODO Don't ignore updates
        self._logger.debug('Receiving a message...')
        body = self.connection.recv()
        message, remote_msg_id, remote_seq = self._decode_msg(body)

        with BinaryReader(message) as reader:
            self._process_msg(
                remote_msg_id, remote_seq, reader, updates=None)

        self._logger.debug('Received message.')

    # endregion

    # region Low level processing

    def _send_packet(self, packet, request):
        """Sends the given packet bytes with the additional
           information of the original request.

           A fresh message id from the session is stored into
           ``request.request_msg_id``; the plain text (salt, session id,
           msg id, seq no, length, packet) is AES-IGE encrypted with the
           session's auth key and sent prefixed by the key id and msg_key.

           This does NOT lock the threads!
        """
        request.request_msg_id = self.session.get_new_msg_id()

        # First calculate plain_text to encrypt it
        with BinaryWriter() as plain_writer:
            plain_writer.write_long(self.session.salt, signed=False)
            plain_writer.write_long(self.session.id, signed=False)
            plain_writer.write_long(request.request_msg_id)
            plain_writer.write_int(
                self.session.generate_sequence(request.content_related))

            plain_writer.write_int(len(packet))
            plain_writer.write(packet)

            msg_key = utils.calc_msg_key(plain_writer.get_bytes())

            key, iv = utils.calc_key(self.session.auth_key.key, msg_key, True)
            cipher_text = AES.encrypt_ige(plain_writer.get_bytes(), key, iv)

        # And then finally send the encrypted packet
        with BinaryWriter() as cipher_writer:
            cipher_writer.write_long(
                self.session.auth_key.key_id, signed=False)
            cipher_writer.write(msg_key)
            cipher_writer.write(cipher_text)
            self.connection.send(cipher_writer.get_bytes())

    def _decode_msg(self, body):
        """Decodes an received encrypted message body bytes

           Returns a ``(message, remote_msg_id, remote_sequence)`` tuple,
           where ``message`` is the still-serialized inner payload.

           Raises BufferError if ``body`` is shorter than 8 bytes.
        """
        if len(body) < 8:
            raise BufferError("Can't decode packet ({})".format(body))

        with BinaryReader(body) as reader:
            # TODO Check for both auth key ID and msg_key correctness
            reader.read_long()  # remote_auth_key_id
            msg_key = reader.read(16)

            key, iv = utils.calc_key(self.session.auth_key.key, msg_key, False)
            plain_text = AES.decrypt_ige(
                reader.read(len(body) - reader.tell_position()), key, iv)

        with BinaryReader(plain_text) as plain_text_reader:
            plain_text_reader.read_long()  # remote_salt
            plain_text_reader.read_long()  # remote_session_id
            remote_msg_id = plain_text_reader.read_long()
            remote_sequence = plain_text_reader.read_int()
            msg_len = plain_text_reader.read_int()
            message = plain_text_reader.read(msg_len)

        return message, remote_msg_id, remote_sequence

    def _process_msg(self, msg_id, sequence, reader, updates):
        """Processes and handles a Telegram message.

           Returns True if the message was handled correctly and doesn't
           need to be skipped. Returns False otherwise.
        """
        # TODO Check salt, session_id and sequence_number
        self._need_confirmation.append(msg_id)

        # Peek at the constructor code; handlers re-read it themselves
        code = reader.read_int(signed=False)
        reader.seek(-4)

        # The following codes are "parsed manually"
        if code == 0xf35c6d01:  # rpc_result, (response of an RPC call)
            return self._handle_rpc_result(msg_id, sequence, reader)

        if code == 0x347773c5:  # pong
            return self._handle_pong(msg_id, sequence, reader)

        if code == 0x73f1f8dc:  # msg_container
            return self._handle_container(msg_id, sequence, reader, updates)

        if code == 0x3072cfa1:  # gzip_packed
            return self._handle_gzip_packed(msg_id, sequence, reader, updates)

        if code == 0xedab447b:  # bad_server_salt
            return self._handle_bad_server_salt(msg_id, sequence, reader)

        if code == 0xa7eff811:  # bad_msg_notification
            return self._handle_bad_msg_notification(msg_id, sequence, reader)

        # msgs_ack, it may handle the request we wanted
        if code == 0x62d6b459:
            ack = reader.tgread_object()
            # Iterate over a snapshot: confirmed requests are removed from
            # the pending list, and removing from the list being iterated
            # would silently skip the element after each removal.
            for r in list(self._pending_receive):
                if r.request_msg_id in ack.msg_ids:
                    self._logger.debug('Ack found for the a request')

                    if self.logging_out:
                        self._logger.debug('Message ack confirmed a request')
                        self._pending_receive.remove(r)
                        r.confirm_received.set()

            return True

        # If the code is not parsed manually then it should be a TLObject.
        if code in tlobjects:
            result = reader.tgread_object()
            if self.unhandled_callbacks:
                self._logger.debug(
                    'Passing TLObject to callbacks %s', repr(result)
                )
                for callback in self.unhandled_callbacks:
                    callback(result)
            else:
                self._logger.debug(
                    'Ignoring unhandled TLObject %s', repr(result)
                )

            return True

        self._logger.debug('Unknown message: %s', hex(code))
        return False

    # endregion

    # region Message handling

    def _pop_request(self, request_msg_id):
        """Pops a pending request from self._pending_receive, or
           returns None if it's not found
        """
        for index, request in enumerate(self._pending_receive):
            if request.request_msg_id == request_msg_id:
                return self._pending_receive.pop(index)
        return None

    def _handle_pong(self, msg_id, sequence, reader):
        """Handles a pong: confirms the pending request whose msg_id the
           pong refers to, if any. Always returns True.
        """
        self._logger.debug('Handling pong')
        reader.read_int(signed=False)  # code
        received_msg_id = reader.read_long()

        request = self._pop_request(received_msg_id)
        if request:
            self._logger.debug('Pong confirmed a request')
            request.confirm_received.set()

        return True

    def _handle_container(self, msg_id, sequence, reader, updates):
        """Processes every inner message of a msg_container.

           When an inner message is not handled (falsy result) or raises,
           the reader is moved to the end of that inner message; exceptions
           are then re-raised. Always returns True otherwise.
        """
        self._logger.debug('Handling container')
        reader.read_int(signed=False)  # code
        size = reader.read_int()
        for _ in range(size):
            inner_msg_id = reader.read_long()
            inner_sequence = reader.read_int()
            inner_length = reader.read_int()
            begin_position = reader.tell_position()

            # Note that this code is IMPORTANT for skipping RPC results of
            # lost requests (i.e., ones from the previous connection session)
            try:
                if not self._process_msg(
                        inner_msg_id, inner_sequence, reader, updates):
                    reader.set_position(begin_position + inner_length)
            except BaseException:
                # If any error is raised, something went wrong; skip the packet
                reader.set_position(begin_position + inner_length)
                raise

        return True

    def _handle_bad_server_salt(self, msg_id, sequence, reader):
        """Stores the server-provided salt in the session and re-sends the
           rejected request if it is still pending. Returns True.
        """
        self._logger.debug('Handling bad server salt')
        reader.read_int(signed=False)  # code
        bad_msg_id = reader.read_long()
        reader.read_int()  # bad_msg_seq_no
        reader.read_int()  # error_code
        new_salt = reader.read_long(signed=False)
        self.session.salt = new_salt

        request = self._pop_request(bad_msg_id)
        if request:
            self.send(request)

        return True

    def _handle_bad_msg_notification(self, msg_id, sequence, reader):
        """Handles a bad_msg_notification.

           For error codes 16/17 (msg_id too low/high) the session's time
           offset is corrected from ``msg_id`` and True is returned. Any
           other error code is raised as a BadMessageError.
           NOTE(review): the rejected request is not re-sent here.
        """
        self._logger.debug('Handling bad message notification')
        reader.read_int(signed=False)  # code
        reader.read_long()  # request_id
        reader.read_int()  # request_sequence

        error_code = reader.read_int()
        error = BadMessageError(error_code)
        if error_code not in (16, 17):
            raise error

        # sent msg_id too low or too high (respectively).
        # Use the current msg_id to determine the right time offset.
        self.session.update_time_offset(correct_msg_id=msg_id)
        self.session.save()
        self._logger.debug('Read Bad Message error: %s', error)
        self._logger.debug('Attempting to use the correct time offset.')
        return True

    def _handle_rpc_result(self, msg_id, sequence, reader):
        """Handles an rpc_result, matching it to the pending request.

           * RPC error: the error is stored in ``request.rpc_error`` (when
             the request is known), acknowledged immediately, and None is
             returned.
           * Result for a known request: the request parses its response
             (gzip-unpacked if needed) and True is returned.
           * Result for an unknown request: False is returned so that a
             surrounding container skips the rest of it.
        """
        self._logger.debug('Handling RPC result')
        reader.read_int(signed=False)  # code
        request_id = reader.read_long()
        inner_code = reader.read_int(signed=False)

        request = self._pop_request(request_id)

        if inner_code == 0x2144ca19:  # RPC Error
            error_code = reader.read_int()
            error_message = reader.tgread_string()
            if self.session.report_errors and request:
                error = rpc_message_to_error(
                    error_code, error_message,
                    report_method=type(request).constructor_id
                )
            else:
                error = rpc_message_to_error(error_code, error_message)

            # Acknowledge that we received the error
            self._need_confirmation.append(request_id)
            self._send_acknowledges()

            if request:
                request.rpc_error = error
                request.confirm_received.set()
            # else TODO Where should this error be reported?
            # Read may be async. Can an error not-belong to a request?
            self._logger.debug('Read RPC error: %s', str(error))
            # Falsy on purpose (unchanged behaviour): inside a container this
            # realigns the reader to the end of the inner message.
            return None

        if not request:
            # If it's really a result for RPC from previous connection
            # session, it will be skipped by the handle_container()
            self._logger.debug('Lost request will be skipped.')
            return False

        self._logger.debug('Reading request response')
        if inner_code == 0x3072cfa1:  # GZip packed
            unpacked_data = gzip.decompress(reader.tgread_bytes())
            with BinaryReader(unpacked_data) as compressed_reader:
                request.on_response(compressed_reader)
        else:
            reader.seek(-4)
            request.on_response(reader)

        request.confirm_received.set()
        return True

    def _handle_gzip_packed(self, msg_id, sequence, reader, updates):
        """Decompresses a gzip_packed payload and processes the inner
           message with the same msg_id and sequence, returning its result.
        """
        self._logger.debug('Handling gzip packed data')
        reader.read_int(signed=False)  # code
        packed_data = reader.tgread_bytes()
        unpacked_data = gzip.decompress(packed_data)

        with BinaryReader(unpacked_data) as compressed_reader:
            return self._process_msg(
                msg_id, sequence, compressed_reader, updates)

    # endregion