Create a proper Message class (msg_id, seqno, body; only .to_bytes())

This commit is contained in:
Lonami Exo 2017-09-27 21:01:20 +02:00
parent b0839a028e
commit bd3dd371a2
4 changed files with 67 additions and 72 deletions

View File

@ -1,5 +1,6 @@
import gzip import gzip
import logging import logging
import struct
from threading import RLock from threading import RLock
from .. import helpers as utils from .. import helpers as utils
@ -8,8 +9,8 @@ from ..errors import (
BadMessageError, InvalidChecksumError, BrokenAuthKeyError, BadMessageError, InvalidChecksumError, BrokenAuthKeyError,
rpc_message_to_error rpc_message_to_error
) )
from ..extensions import BinaryReader, BinaryWriter from ..extensions import BinaryReader
from ..tl import MessageContainer, GzipPacked from ..tl import Message, MessageContainer, GzipPacked
from ..tl.all_tlobjects import tlobjects from ..tl.all_tlobjects import tlobjects
from ..tl.types import MsgsAck from ..tl.types import MsgsAck
@ -29,8 +30,11 @@ class MtProtoSender:
self.connection = connection self.connection = connection
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
self._need_confirmation = [] # Message IDs that need confirmation # Message IDs that need confirmation
self._pending_receive = [] # Requests sent waiting to be received self._need_confirmation = []
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
# Sending and receiving are independent, but two threads cannot # Sending and receiving are independent, but two threads cannot
# send or receive at the same time no matter what. # send or receive at the same time no matter what.
@ -65,21 +69,22 @@ class MtProtoSender:
self._send_acknowledges() self._send_acknowledges()
# Finally send our packed request(s) # Finally send our packed request(s)
self._pending_receive.extend(requests) messages = [Message(self.session, r) for r in requests]
if len(requests) == 1: self._pending_receive.update({m.msg_id: m for m in messages})
request = requests[0]
data = GzipPacked.gzip_if_smaller(request)
else:
request = MessageContainer(self.session, requests)
data = request.to_bytes()
self._send_packet(data, request) if len(messages) == 1:
message = messages[0]
else:
message = Message(self.session, MessageContainer(messages))
self._send_message(message)
def _send_acknowledges(self): def _send_acknowledges(self):
"""Sends a messages acknowledge for all those who _need_confirmation""" """Sends a messages acknowledge for all those who _need_confirmation"""
if self._need_confirmation: if self._need_confirmation:
msgs_ack = MsgsAck(self._need_confirmation) self._send_message(
self._send_packet(msgs_ack.to_bytes(), msgs_ack) Message(self.session, MsgsAck(self._need_confirmation))
)
del self._need_confirmation[:] del self._need_confirmation[:]
def receive(self, update_state): def receive(self, update_state):
@ -114,36 +119,21 @@ class MtProtoSender:
# region Low level processing # region Low level processing
def _send_packet(self, packet, request): def _send_message(self, message):
"""Sends the given packet bytes with the additional """Sends the given Message(TLObject) encrypted through the network"""
information of the original request.
"""
request.request_msg_id = self.session.get_new_msg_id()
# First calculate plain_text to encrypt it plain_text = \
with BinaryWriter() as plain_writer: struct.pack('<QQ', self.session.salt, self.session.id) \
plain_writer.write_long(self.session.salt, signed=False) + message.to_bytes()
plain_writer.write_long(self.session.id, signed=False)
plain_writer.write_long(request.request_msg_id)
plain_writer.write_int(
self.session.generate_sequence(request.content_related))
plain_writer.write_int(len(packet)) msg_key = utils.calc_msg_key(plain_text)
plain_writer.write(packet) key_id = struct.pack('<q', self.session.auth_key.key_id)
key, iv = utils.calc_key(self.session.auth_key.key, msg_key, True)
cipher_text = AES.encrypt_ige(plain_text, key, iv)
msg_key = utils.calc_msg_key(plain_writer.get_bytes()) result = key_id + msg_key + cipher_text
with self._send_lock:
key, iv = utils.calc_key(self.session.auth_key.key, msg_key, True) self.connection.send(result)
cipher_text = AES.encrypt_ige(plain_writer.get_bytes(), key, iv)
# And then finally send the encrypted packet
with BinaryWriter() as cipher_writer:
cipher_writer.write_long(
self.session.auth_key.key_id, signed=False)
cipher_writer.write(msg_key)
cipher_writer.write(cipher_text)
with self._send_lock:
self.connection.send(cipher_writer.get_bytes())
def _decode_msg(self, body): def _decode_msg(self, body):
"""Decodes an received encrypted message body bytes""" """Decodes an received encrypted message body bytes"""
@ -211,13 +201,12 @@ class MtProtoSender:
# msgs_ack, it may handle the request we wanted # msgs_ack, it may handle the request we wanted
if code == 0x62d6b459: if code == 0x62d6b459:
ack = reader.tgread_object() ack = reader.tgread_object()
for r in self._pending_receive: # We only care about ack requests if we're logging out
if r.request_msg_id in ack.msg_ids: if self.logging_out:
self._logger.debug('Ack found for the a request') for msg_id in ack.msg_ids:
r = self._pop_request(msg_id)
if self.logging_out: if r:
self._logger.debug('Message ack confirmed a request') self._logger.debug('Message ack confirmed', r)
self._pending_receive.remove(r)
r.confirm_received.set() r.confirm_received.set()
return True return True
@ -244,16 +233,16 @@ class MtProtoSender:
# region Message handling # region Message handling
def _pop_request(self, request_msg_id): def _pop_request(self, msg_id):
"""Pops a pending request from self._pending_receive, or """Pops a pending REQUEST from self._pending_receive, or
returns None if it's not found returns None if it's not found.
""" """
for i in range(len(self._pending_receive)): message = self._pending_receive.pop(msg_id, None)
if self._pending_receive[i].request_msg_id == request_msg_id: if message:
return self._pending_receive.pop(i) return message.request
def _clear_all_pending(self): def _clear_all_pending(self):
for r in self._pending_receive: for r in self._pending_receive.values():
r.confirm_received.set() r.confirm_received.set()
self._pending_receive.clear() self._pending_receive.clear()

View File

@ -1,4 +1,5 @@
from .tlobject import TLObject from .tlobject import TLObject
from .session import Session from .session import Session
from .gzip_packed import GzipPacked from .gzip_packed import GzipPacked
from .message import Message
from .message_container import MessageContainer from .message_container import MessageContainer

17
telethon/tl/message.py Normal file
View File

@ -0,0 +1,17 @@
import struct
from . import TLObject, GzipPacked
class Message(TLObject):
"""https://core.telegram.org/mtproto/service_messages#simple-container"""
def __init__(self, session, request):
super().__init__()
del self.content_related
self.msg_id = session.get_new_msg_id()
self.seq_no = session.generate_sequence(request.content_related)
self.request = request
def to_bytes(self):
body = GzipPacked.gzip_if_smaller(self.request)
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body

View File

@ -5,31 +5,19 @@ from ..extensions import BinaryWriter
class MessageContainer(TLObject): class MessageContainer(TLObject):
constructor_id = 0x73f1f8dc constructor_id = 0x73f1f8dc
# TODO Currently it's a bit of a hack, since the container actually holds def __init__(self, messages):
# messages (message id, sequence number, request body), not requests.
# Probably create a proper "Message" class
def __init__(self, session, requests):
super().__init__() super().__init__()
self.content_related = False self.content_related = False
self.session = session self.messages = messages
self.requests = requests
def to_bytes(self): def to_bytes(self):
# TODO Change this to delete the on_send from this class # TODO Change this to delete the on_send from this class
with BinaryWriter() as writer: with BinaryWriter() as writer:
writer.write_int(MessageContainer.constructor_id, signed=False) writer.write_int(MessageContainer.constructor_id, signed=False)
writer.write_int(len(self.requests)) writer.write_int(len(self.messages))
for x in self.requests: for m in self.messages:
x.request_msg_id = self.session.get_new_msg_id() writer.write(m.to_bytes())
writer.write_long(x.request_msg_id)
writer.write_int(
self.session.generate_sequence(x.content_related)
)
packet = GzipPacked.gzip_if_smaller(x)
writer.write_int(len(packet))
writer.write(packet)
return writer.get_bytes() return writer.get_bytes()
@staticmethod @staticmethod