Make generate_sequence() thread-safe and move it to Session

This commit is contained in:
Lonami Exo 2017-06-07 12:00:01 +01:00
parent 49ca5c00c7
commit 1860054ec0
2 changed files with 36 additions and 15 deletions

View File

@ -39,16 +39,6 @@ class MtProtoSender:
"""Disconnects from the server""" """Disconnects from the server"""
self._transport.close() self._transport.close()
def _generate_sequence(self, confirmed):
"""Generates the next sequence number, based on whether it
was confirmed yet or not"""
if confirmed:
result = self.session.sequence * 2 + 1
self.session.sequence += 1
return result
else:
return self.session.sequence * 2
# region Send and receive # region Send and receive
def send(self, request): def send(self, request):
@ -142,7 +132,9 @@ class MtProtoSender:
plain_writer.write_long(self.session.salt, signed=False) plain_writer.write_long(self.session.salt, signed=False)
plain_writer.write_long(self.session.id, signed=False) plain_writer.write_long(self.session.id, signed=False)
plain_writer.write_long(request.msg_id) plain_writer.write_long(request.msg_id)
plain_writer.write_int(self._generate_sequence(request.confirmed)) plain_writer.write_int(
self.session.generate_sequence(request.confirmed))
plain_writer.write_int(len(packet)) plain_writer.write_int(len(packet))
plain_writer.write(packet) plain_writer.write(packet)

View File

@ -3,6 +3,7 @@ import os
import pickle import pickle
import random import random
import time import time
from threading import Lock
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
from os.path import isfile as file_exists from os.path import isfile as file_exists
@ -52,6 +53,16 @@ class Session:
else: else:
return Session(session_user_id) return Session(session_user_id)
def generate_sequence(self, confirmed):
"""Ported from JsonSession.generate_sequence"""
with Lock():
if confirmed:
result = self.sequence * 2 + 1
self.sequence += 1
return result
else:
return self.sequence * 2
def get_new_msg_id(self): def get_new_msg_id(self):
"""Generates a new message ID based on the current time (in ms) since epoch""" """Generates a new message ID based on the current time (in ms) since epoch"""
# Refer to mtproto_plain_sender.py for the original method, this is a simple copy # Refer to mtproto_plain_sender.py for the original method, this is a simple copy
@ -92,11 +103,14 @@ class JsonSession:
self.port = 443 self.port = 443
self.auth_key = None self.auth_key = None
self.id = utils.generate_random_long(signed=False) self.id = utils.generate_random_long(signed=False)
self.sequence = 0 self._sequence = 0
self.salt = 0 # Unsigned long self.salt = 0 # Unsigned long
self.time_offset = 0 self.time_offset = 0
self.last_message_id = 0 # Long self.last_message_id = 0 # Long
# Cross-thread safety
self._lock = Lock()
def save(self): def save(self):
"""Saves the current session object as session_user_id.session""" """Saves the current session object as session_user_id.session"""
if self.session_user_id: if self.session_user_id:
@ -105,7 +119,7 @@ class JsonSession:
'id': self.id, 'id': self.id,
'port': self.port, 'port': self.port,
'salt': self.salt, 'salt': self.salt,
'sequence': self.sequence, 'sequence': self._sequence,
'time_offset': self.time_offset, 'time_offset': self.time_offset,
'server_address': self.server_address, 'server_address': self.server_address,
'auth_key_data': 'auth_key_data':
@ -139,7 +153,7 @@ class JsonSession:
result.id = data['id'] result.id = data['id']
result.port = data['port'] result.port = data['port']
result.salt = data['salt'] result.salt = data['salt']
result.sequence = data['sequence'] result._sequence = data['sequence']
result.time_offset = data['time_offset'] result.time_offset = data['time_offset']
result.server_address = data['server_address'] result.server_address = data['server_address']
@ -155,7 +169,7 @@ class JsonSession:
result.id = old.id result.id = old.id
result.port = old.port result.port = old.port
result.salt = old.salt result.salt = old.salt
result.sequence = old.sequence result._sequence = old.sequence
result.time_offset = old.time_offset result.time_offset = old.time_offset
result.server_address = old.server_address result.server_address = old.server_address
result.auth_key = old.auth_key result.auth_key = old.auth_key
@ -163,6 +177,21 @@ class JsonSession:
return result return result
def generate_sequence(self, confirmed):
"""Thread safe method to generates the next sequence number,
based on whether it was confirmed yet or not.
Note that if confirmed=True, the sequence number
will be increased by one too
"""
with self._lock:
if confirmed:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2
def get_new_msg_id(self): def get_new_msg_id(self):
"""Generates a new unique message ID based on the current """Generates a new unique message ID based on the current
time (in ms) since epoch""" time (in ms) since epoch"""