Make generate_sequence() thread-safe and move it to Session

This commit is contained in:
Lonami Exo 2017-06-07 12:00:01 +01:00
parent 49ca5c00c7
commit 1860054ec0
2 changed files with 36 additions and 15 deletions

View File

@ -39,16 +39,6 @@ class MtProtoSender:
"""Disconnects from the server"""
self._transport.close()
def _generate_sequence(self, confirmed):
"""Generates the next sequence number, based on whether it
was confirmed yet or not"""
if confirmed:
result = self.session.sequence * 2 + 1
self.session.sequence += 1
return result
else:
return self.session.sequence * 2
# region Send and receive
def send(self, request):
@ -142,7 +132,9 @@ class MtProtoSender:
plain_writer.write_long(self.session.salt, signed=False)
plain_writer.write_long(self.session.id, signed=False)
plain_writer.write_long(request.msg_id)
plain_writer.write_int(self._generate_sequence(request.confirmed))
plain_writer.write_int(
self.session.generate_sequence(request.confirmed))
plain_writer.write_int(len(packet))
plain_writer.write(packet)

View File

@ -3,6 +3,7 @@ import os
import pickle
import random
import time
from threading import Lock
from base64 import b64encode, b64decode
from os.path import isfile as file_exists
@ -52,6 +53,16 @@ class Session:
else:
return Session(session_user_id)
def generate_sequence(self, confirmed):
"""Ported from JsonSession.generate_sequence"""
with Lock():
if confirmed:
result = self.sequence * 2 + 1
self.sequence += 1
return result
else:
return self.sequence * 2
def get_new_msg_id(self):
"""Generates a new message ID based on the current time (in ms) since epoch"""
# Refer to mtproto_plain_sender.py for the original method, this is a simple copy
@ -92,11 +103,14 @@ class JsonSession:
self.port = 443
self.auth_key = None
self.id = utils.generate_random_long(signed=False)
self.sequence = 0
self._sequence = 0
self.salt = 0 # Unsigned long
self.time_offset = 0
self.last_message_id = 0 # Long
# Cross-thread safety
self._lock = Lock()
def save(self):
"""Saves the current session object as session_user_id.session"""
if self.session_user_id:
@ -105,7 +119,7 @@ class JsonSession:
'id': self.id,
'port': self.port,
'salt': self.salt,
'sequence': self.sequence,
'sequence': self._sequence,
'time_offset': self.time_offset,
'server_address': self.server_address,
'auth_key_data':
@ -139,7 +153,7 @@ class JsonSession:
result.id = data['id']
result.port = data['port']
result.salt = data['salt']
result.sequence = data['sequence']
result._sequence = data['sequence']
result.time_offset = data['time_offset']
result.server_address = data['server_address']
@ -155,7 +169,7 @@ class JsonSession:
result.id = old.id
result.port = old.port
result.salt = old.salt
result.sequence = old.sequence
result._sequence = old.sequence
result.time_offset = old.time_offset
result.server_address = old.server_address
result.auth_key = old.auth_key
@ -163,6 +177,21 @@ class JsonSession:
return result
def generate_sequence(self, confirmed):
"""Thread safe method to generates the next sequence number,
based on whether it was confirmed yet or not.
Note that if confirmed=True, the sequence number
will be increased by one too
"""
with self._lock:
if confirmed:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2
def get_new_msg_id(self):
"""Generates a new unique message ID based on the current
time (in ms) since epoch"""