diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index 84dac60c..35f065ab 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -39,16 +39,6 @@ class MtProtoSender: """Disconnects from the server""" self._transport.close() - def _generate_sequence(self, confirmed): - """Generates the next sequence number, based on whether it - was confirmed yet or not""" - if confirmed: - result = self.session.sequence * 2 + 1 - self.session.sequence += 1 - return result - else: - return self.session.sequence * 2 - # region Send and receive def send(self, request): @@ -142,7 +132,9 @@ class MtProtoSender: plain_writer.write_long(self.session.salt, signed=False) plain_writer.write_long(self.session.id, signed=False) plain_writer.write_long(request.msg_id) - plain_writer.write_int(self._generate_sequence(request.confirmed)) + plain_writer.write_int( + self.session.generate_sequence(request.confirmed)) + plain_writer.write_int(len(packet)) plain_writer.write(packet) diff --git a/telethon/tl/session.py b/telethon/tl/session.py index 14140648..a451b9f1 100644 --- a/telethon/tl/session.py +++ b/telethon/tl/session.py @@ -3,6 +3,7 @@ import os import pickle import random import time +from threading import Lock from base64 import b64encode, b64decode from os.path import isfile as file_exists @@ -52,6 +53,16 @@ class Session: else: return Session(session_user_id) + def generate_sequence(self, confirmed): + """Ported from JsonSession.generate_sequence""" + with Lock(): + if confirmed: + result = self.sequence * 2 + 1 + self.sequence += 1 + return result + else: + return self.sequence * 2 + def get_new_msg_id(self): """Generates a new message ID based on the current time (in ms) since epoch""" # Refer to mtproto_plain_sender.py for the original method, this is a simple copy @@ -92,11 +103,14 @@ class JsonSession: self.port = 443 self.auth_key = None self.id = utils.generate_random_long(signed=False) - self.sequence = 0 + self._sequence = 0 self.salt = 0 # Unsigned long self.time_offset = 0 self.last_message_id = 0 # Long + # Cross-thread safety + self._lock = Lock() + def save(self): """Saves the current session object as session_user_id.session""" if self.session_user_id: @@ -105,7 +119,7 @@ class JsonSession: 'id': self.id, 'port': self.port, 'salt': self.salt, - 'sequence': self.sequence, + 'sequence': self._sequence, 'time_offset': self.time_offset, 'server_address': self.server_address, 'auth_key_data': @@ -139,7 +153,7 @@ class JsonSession: result.id = data['id'] result.port = data['port'] result.salt = data['salt'] - result.sequence = data['sequence'] + result._sequence = data['sequence'] result.time_offset = data['time_offset'] result.server_address = data['server_address'] @@ -155,7 +169,7 @@ class JsonSession: result.id = old.id result.port = old.port result.salt = old.salt - result.sequence = old.sequence + result._sequence = old.sequence result.time_offset = old.time_offset result.server_address = old.server_address result.auth_key = old.auth_key @@ -163,6 +177,21 @@ class JsonSession: return result + def generate_sequence(self, confirmed): + """Thread safe method to generates the next sequence number, + based on whether it was confirmed yet or not. + + Note that if confirmed=True, the sequence number + will be increased by one too + """ + with self._lock: + if confirmed: + result = self._sequence * 2 + 1 + self._sequence += 1 + return result + else: + return self._sequence * 2 + def get_new_msg_id(self): """Generates a new unique message ID based on the current time (in ms) since epoch"""